- 
                Notifications
    You must be signed in to change notification settings 
- Fork 15k
[MacroFusion] Support commutable instructions #82751
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
1b286ff    to
    9d9a50d      
    Compare
  
    9d9a50d    to
    048d1bd      
    Compare
  
    048d1bd    to
    74c60af      
    Compare
  
    | @llvm/pr-subscribers-backend-risc-v Author: Wang Pengcheng (wangpc-pp) ChangesIf the second instruction is commutable, we should be able to check A simple RISCV fusion is contained in this PR to show the functionality There are some other issues I should fix. For example, we should be Fixes #82738 Full diff: https://github.com/llvm/llvm-project/pull/82751.diff 5 Files Affected: 
 diff --git a/llvm/include/llvm/Target/TargetSchedule.td b/llvm/include/llvm/Target/TargetSchedule.td
index 48c9387977c075..a872cc30a9b50c 100644
--- a/llvm/include/llvm/Target/TargetSchedule.td
+++ b/llvm/include/llvm/Target/TargetSchedule.td
@@ -617,16 +617,27 @@ class SecondFusionPredicateWithMCInstPredicate<MCInstPredicate pred>
   : FusionPredicateWithMCInstPredicate<second_fusion_target, pred>;
 // The pred will be applied on both firstMI and secondMI.
 class BothFusionPredicateWithMCInstPredicate<MCInstPredicate pred>
-  : FusionPredicateWithMCInstPredicate<second_fusion_target, pred>;
+  : FusionPredicateWithMCInstPredicate<both_fusion_target, pred>;
 
 // Tie firstOpIdx and secondOpIdx. The operand of `FirstMI` at position
 // `firstOpIdx` should be the same as the operand of `SecondMI` at position
 // `secondOpIdx`.
+// If the operand at `secondOpIdx` has commutable operand, then the commutable
+// operand will be checked too.
 class TieReg<int firstOpIdx, int secondOpIdx> : BothFusionPredicate {
   int FirstOpIdx = firstOpIdx;
   int SecondOpIdx = secondOpIdx;
 }
 
+// The operand of `SecondMI` at position `firstOpIdx` should be the same as the
+// operand at position `secondOpIdx`.
+// If the operand at `secondOpIdx` has commutable operand, then the commutable
+// operand will be checked too.
+class SameReg<int firstOpIdx, int secondOpIdx> : SecondFusionPredicate {
+  int FirstOpIdx = firstOpIdx;
+  int SecondOpIdx = secondOpIdx;
+}
+
 // A predicate for wildcard. The generated code will be like:
 // ```
 // if (!FirstMI)
@@ -688,11 +699,7 @@ class SimpleFusion<string name, string fieldName, string desc,
                 SecondFusionPredicateWithMCInstPredicate<secondPred>,
                 WildcardTrue,
                 FirstFusionPredicateWithMCInstPredicate<firstPred>,
-                SecondFusionPredicateWithMCInstPredicate<
-                  CheckAny<[
-                    CheckIsVRegOperand<0>,
-                    CheckSameRegOperand<0, 1>
-                  ]>>,
+                SameReg<0, 1>,
                 OneUse,
                 TieReg<0, 1>,
               ],
diff --git a/llvm/lib/Target/RISCV/RISCVMacroFusion.td b/llvm/lib/Target/RISCV/RISCVMacroFusion.td
index 875a93d09a2c64..14e8962f8ce110 100644
--- a/llvm/lib/Target/RISCV/RISCVMacroFusion.td
+++ b/llvm/lib/Target/RISCV/RISCVMacroFusion.td
@@ -91,3 +91,24 @@ def TuneLDADDFusion
                    CheckIsImmOperand<2>,
                    CheckImmOperand<2, 0>
                  ]>>;
+
+// These should be covered by Zba extension.
+// * shift left by one and add:
+//   slli r1, r0, 1
+//   add r1, r1, r2
+// * shift left by two and add:
+//   slli r1, r0, 2
+//   add r1, r1, r2
+// * shift left by three and add:
+//   slli r1, r0, 3
+//   add r1, r1, r2
+def ShiftNAddFusion
+  : SimpleFusion<"shift-n-add-fusion", "HasShiftNAddFusion",
+                 "Enable SLLI+ADD to be fused to shift left by 1/2/3 and add",
+                 CheckAll<[
+                   CheckOpcode<[SLLI]>,
+                   CheckAny<[CheckImmOperand<2, 1>,
+                             CheckImmOperand<2, 2>,
+                             CheckImmOperand<2, 3>]>
+                 ]>,
+                 CheckOpcode<[ADD]>>;
diff --git a/llvm/test/CodeGen/RISCV/macro-fusions.mir b/llvm/test/CodeGen/RISCV/macro-fusions.mir
index 13464141ce27e6..11bb456e0ca050 100644
--- a/llvm/test/CodeGen/RISCV/macro-fusions.mir
+++ b/llvm/test/CodeGen/RISCV/macro-fusions.mir
@@ -1,7 +1,7 @@
 # REQUIRES: asserts
 # RUN: llc -mtriple=riscv64-linux-gnu -x=mir < %s \
 # RUN:   -debug-only=machine-scheduler -start-before=machine-scheduler 2>&1 \
-# RUN:   -mattr=+lui-addi-fusion,+auipc-addi-fusion,+zexth-fusion,+zextw-fusion,+shifted-zextw-fusion,+ld-add-fusion \
+# RUN:   -mattr=+lui-addi-fusion,+auipc-addi-fusion,+zexth-fusion,+zextw-fusion,+shifted-zextw-fusion,+ld-add-fusion,+shift-n-add-fusion \
 # RUN:   | FileCheck %s
 
 # CHECK: lui_addi:%bb.0
@@ -174,3 +174,31 @@ body:             |
     $x11 = COPY %5
     PseudoRET
 ...
+
+# CHECK: shift_n_add:%bb.0
+# CHECK: Macro fuse: {{.*}}SLLI - ADD
+---
+name: shift_n_add
+tracksRegLiveness: true
+body:             |
+  bb.0.entry:
+    liveins: $x10, $x11
+    $x10 = SLLI $x10, 1
+    $x12 = XORI $x11, 3
+    $x10 = ADD $x10, $x11
+    PseudoRET
+...
+
+# CHECK: shift_n_add_commutable:%bb.0
+# CHECK: Macro fuse: {{.*}}SLLI - ADD
+---
+name: shift_n_add_commutable
+tracksRegLiveness: true
+body:             |
+  bb.0.entry:
+    liveins: $x10, $x11
+    $x10 = SLLI $x10, 1
+    $x12 = XORI $x11, 3
+    $x10 = ADD $x11, $x10
+    PseudoRET
+...
diff --git a/llvm/test/TableGen/MacroFusion.td b/llvm/test/TableGen/MacroFusion.td
index 4aa6c8d9acb273..c1cdd0fee78ccd 100644
--- a/llvm/test/TableGen/MacroFusion.td
+++ b/llvm/test/TableGen/MacroFusion.td
@@ -34,6 +34,11 @@ let Namespace = "Test" in {
 def Inst0 : TestInst<0>;
 def Inst1 : TestInst<1>;
 
+def BothFusionPredicate: BothFusionPredicateWithMCInstPredicate<CheckRegOperand<0, X0>>;
+def TestBothFusionPredicate: Fusion<"test-both-fusion-predicate", "HasBothFusionPredicate",
+                                    "Test BothFusionPredicate",
+                                    [BothFusionPredicate]>;
+
 def TestFusion: SimpleFusion<"test-fusion", "HasTestFusion", "Test Fusion",
                              CheckOpcode<[Inst0]>,
                              CheckAll<[
@@ -45,6 +50,7 @@ def TestFusion: SimpleFusion<"test-fusion", "HasTestFusion", "Test Fusion",
 // CHECK-PREDICATOR-NEXT:  #undef GET_Test_MACRO_FUSION_PRED_DECL
 // CHECK-PREDICATOR-EMPTY:
 // CHECK-PREDICATOR-NEXT:  namespace llvm {
+// CHECK-PREDICATOR-NEXT:  bool isTestBothFusionPredicate(const TargetInstrInfo &, const TargetSubtargetInfo &, const MachineInstr *, const MachineInstr &);
 // CHECK-PREDICATOR-NEXT:  bool isTestFusion(const TargetInstrInfo &, const TargetSubtargetInfo &, const MachineInstr *, const MachineInstr &);
 // CHECK-PREDICATOR-NEXT:  } // end namespace llvm
 // CHECK-PREDICATOR-EMPTY:
@@ -54,6 +60,24 @@ def TestFusion: SimpleFusion<"test-fusion", "HasTestFusion", "Test Fusion",
 // CHECK-PREDICATOR-NEXT:  #undef GET_Test_MACRO_FUSION_PRED_IMPL
 // CHECK-PREDICATOR-EMPTY:
 // CHECK-PREDICATOR-NEXT:  namespace llvm {
+// CHECK-PREDICATOR-NEXT:  bool isTestBothFusionPredicate(
+// CHECK-PREDICATOR-NEXT:      const TargetInstrInfo &TII,
+// CHECK-PREDICATOR-NEXT:      const TargetSubtargetInfo &STI,
+// CHECK-PREDICATOR-NEXT:      const MachineInstr *FirstMI,
+// CHECK-PREDICATOR-NEXT:      const MachineInstr &SecondMI) {
+// CHECK-PREDICATOR-NEXT:    auto &MRI = SecondMI.getMF()->getRegInfo();
+// CHECK-PREDICATOR-NEXT:    {
+// CHECK-PREDICATOR-NEXT:      const MachineInstr *MI = FirstMI;
+// CHECK-PREDICATOR-NEXT:      if (MI->getOperand(0).getReg() != Test::X0)
+// CHECK-PREDICATOR-NEXT:        return false;
+// CHECK-PREDICATOR-NEXT:    }
+// CHECK-PREDICATOR-NEXT:    {
+// CHECK-PREDICATOR-NEXT:      const MachineInstr *MI = &SecondMI;
+// CHECK-PREDICATOR-NEXT:      if (MI->getOperand(0).getReg() != Test::X0)
+// CHECK-PREDICATOR-NEXT:        return false;
+// CHECK-PREDICATOR-NEXT:    }
+// CHECK-PREDICATOR-NEXT:    return true;
+// CHECK-PREDICATOR-NEXT:  }
 // CHECK-PREDICATOR-NEXT:  bool isTestFusion(
 // CHECK-PREDICATOR-NEXT:      const TargetInstrInfo &TII,
 // CHECK-PREDICATOR-NEXT:      const TargetSubtargetInfo &STI,
@@ -75,13 +99,15 @@ def TestFusion: SimpleFusion<"test-fusion", "HasTestFusion", "Test Fusion",
 // CHECK-PREDICATOR-NEXT:      if (( MI->getOpcode() != Test::Inst0 ))
 // CHECK-PREDICATOR-NEXT:        return false;
 // CHECK-PREDICATOR-NEXT:    }
-// CHECK-PREDICATOR-NEXT:    {
-// CHECK-PREDICATOR-NEXT:      const MachineInstr *MI = &SecondMI;
-// CHECK-PREDICATOR-NEXT:      if (!(
-// CHECK-PREDICATOR-NEXT:          MI->getOperand(0).getReg().isVirtual()
-// CHECK-PREDICATOR-NEXT:          || MI->getOperand(0).getReg() == MI->getOperand(1).getReg()
-// CHECK-PREDICATOR-NEXT:        ))
-// CHECK-PREDICATOR-NEXT:        return false;
+// CHECK-PREDICATOR-NEXT:    if (!SecondMI.getOperand(0).getReg().isVirtual()) {                                                                                                                                                                                      
+// CHECK-PREDICATOR-NEXT:      if (SecondMI.getOperand(0).getReg() != SecondMI.getOperand(1).getReg()) {                                                                                                                                                              
+// CHECK-PREDICATOR-NEXT:        if (!SecondMI.getDesc().isCommutable())                                                                                                                                                                                              
+// CHECK-PREDICATOR-NEXT:          return false;
+// CHECK-PREDICATOR-NEXT:        unsigned SrcOpIdx1 = 1, SrcOpIdx2 = TargetInstrInfo::CommuteAnyOperandIndex;
+// CHECK-PREDICATOR-NEXT:        if (TII.findCommutedOpIndices(SecondMI, SrcOpIdx1, SrcOpIdx2))
+// CHECK-PREDICATOR-NEXT:          if (SecondMI.getOperand(0).getReg() != SecondMI.getOperand(SrcOpIdx2).getReg())
+// CHECK-PREDICATOR-NEXT:            return false;
+// CHECK-PREDICATOR-NEXT:      }
 // CHECK-PREDICATOR-NEXT:    }
 // CHECK-PREDICATOR-NEXT:    {
 // CHECK-PREDICATOR-NEXT:      Register FirstDest = FirstMI->getOperand(0).getReg();
@@ -90,8 +116,14 @@ def TestFusion: SimpleFusion<"test-fusion", "HasTestFusion", "Test Fusion",
 // CHECK-PREDICATOR-NEXT:    }
 // CHECK-PREDICATOR-NEXT:    if (!(FirstMI->getOperand(0).isReg() &&
 // CHECK-PREDICATOR-NEXT:          SecondMI.getOperand(1).isReg() &&
-// CHECK-PREDICATOR-NEXT:          FirstMI->getOperand(0).getReg() == SecondMI.getOperand(1).getReg()))
-// CHECK-PREDICATOR-NEXT:      return false;
+// CHECK-PREDICATOR-NEXT:          FirstMI->getOperand(0).getReg() == SecondMI.getOperand(1).getReg())) {
+// CHECK-PREDICATOR-NEXT:      if (!SecondMI.getDesc().isCommutable())
+// CHECK-PREDICATOR-NEXT:        return false;
+// CHECK-PREDICATOR-NEXT:      unsigned SrcOpIdx1 = 1, SrcOpIdx2 = TargetInstrInfo::CommuteAnyOperandIndex;
+// CHECK-PREDICATOR-NEXT:      if (TII.findCommutedOpIndices(SecondMI, SrcOpIdx1, SrcOpIdx2))
+// CHECK-PREDICATOR-NEXT:        if (FirstMI->getOperand(0).getReg() != SecondMI.getOperand(SrcOpIdx2).getReg())
+// CHECK-PREDICATOR-NEXT:          return false;
+// CHECK-PREDICATOR-NEXT:    }
 // CHECK-PREDICATOR-NEXT:    return true;
 // CHECK-PREDICATOR-NEXT:  }
 // CHECK-PREDICATOR-NEXT:  } // end namespace llvm
@@ -106,6 +138,7 @@ def TestFusion: SimpleFusion<"test-fusion", "HasTestFusion", "Test Fusion",
 
 // CHECK-SUBTARGET:      std::vector<MacroFusionPredTy> TestGenSubtargetInfo::getMacroFusions() const {
 // CHECK-SUBTARGET-NEXT:   std::vector<MacroFusionPredTy> Fusions;
+// CHECK-SUBTARGET-NEXT:   if (hasFeature(Test::TestBothFusionPredicate)) Fusions.push_back(llvm::isTestBothFusionPredicate); 
 // CHECK-SUBTARGET-NEXT:   if (hasFeature(Test::TestFusion)) Fusions.push_back(llvm::isTestFusion);
 // CHECK-SUBTARGET-NEXT:   return Fusions;
 // CHECK-SUBTARGET-NEXT: }
diff --git a/llvm/utils/TableGen/MacroFusionPredicatorEmitter.cpp b/llvm/utils/TableGen/MacroFusionPredicatorEmitter.cpp
index 78dcd4471ae747..1042dd9c2dbccf 100644
--- a/llvm/utils/TableGen/MacroFusionPredicatorEmitter.cpp
+++ b/llvm/utils/TableGen/MacroFusionPredicatorEmitter.cpp
@@ -152,8 +152,7 @@ void MacroFusionPredicatorEmitter::emitFirstPredicate(Record *Predicate,
         << "if (FirstDest.isVirtual() && !MRI.hasOneNonDBGUse(FirstDest))\n";
     OS.indent(4) << "  return false;\n";
     OS.indent(2) << "}\n";
-  } else if (Predicate->isSubClassOf(
-                 "FirstFusionPredicateWithMCInstPredicate")) {
+  } else if (Predicate->isSubClassOf("FusionPredicateWithMCInstPredicate")) {
     OS.indent(2) << "{\n";
     OS.indent(4) << "const MachineInstr *MI = FirstMI;\n";
     OS.indent(4) << "if (";
@@ -173,7 +172,7 @@ void MacroFusionPredicatorEmitter::emitFirstPredicate(Record *Predicate,
 void MacroFusionPredicatorEmitter::emitSecondPredicate(Record *Predicate,
                                                        PredicateExpander &PE,
                                                        raw_ostream &OS) {
-  if (Predicate->isSubClassOf("SecondFusionPredicateWithMCInstPredicate")) {
+  if (Predicate->isSubClassOf("FusionPredicateWithMCInstPredicate")) {
     OS.indent(2) << "{\n";
     OS.indent(4) << "const MachineInstr *MI = &SecondMI;\n";
     OS.indent(4) << "if (";
@@ -183,9 +182,31 @@ void MacroFusionPredicatorEmitter::emitSecondPredicate(Record *Predicate,
     OS << ")\n";
     OS.indent(4) << "  return false;\n";
     OS.indent(2) << "}\n";
+  } else if (Predicate->isSubClassOf("SameReg")) {
+    int FirstOpIdx = Predicate->getValueAsInt("FirstOpIdx");
+    int SecondOpIdx = Predicate->getValueAsInt("SecondOpIdx");
+
+    OS.indent(2) << "if (!SecondMI.getOperand(" << FirstOpIdx
+                 << ").getReg().isVirtual()) {\n";
+    OS.indent(4) << "if (SecondMI.getOperand(" << FirstOpIdx
+                 << ").getReg() != SecondMI.getOperand(" << SecondOpIdx
+                 << ").getReg()) {\n";
+
+    OS.indent(6) << "if (!SecondMI.getDesc().isCommutable())\n";
+    OS.indent(6) << "  return false;\n";
+
+    OS.indent(6) << "unsigned SrcOpIdx1 = " << SecondOpIdx
+                 << ", SrcOpIdx2 = TargetInstrInfo::CommuteAnyOperandIndex;\n";
+    OS.indent(6)
+        << "if (TII.findCommutedOpIndices(SecondMI, SrcOpIdx1, SrcOpIdx2))\n";
+    OS.indent(6) << "  if (SecondMI.getOperand(" << FirstOpIdx
+                 << ").getReg() != SecondMI.getOperand(SrcOpIdx2).getReg())\n";
+    OS.indent(6) << "    return false;\n";
+    OS.indent(4) << "}\n";
+    OS.indent(2) << "}\n";
   } else {
     PrintFatalError(Predicate->getLoc(),
-                    "Unsupported predicate for first instruction: " +
+                    "Unsupported predicate for second instruction: " +
                         Predicate->getType()->getAsString());
   }
 }
@@ -196,9 +217,8 @@ void MacroFusionPredicatorEmitter::emitBothPredicate(Record *Predicate,
   if (Predicate->isSubClassOf("FusionPredicateWithCode"))
     OS << Predicate->getValueAsString("Predicate");
   else if (Predicate->isSubClassOf("BothFusionPredicateWithMCInstPredicate")) {
-    Record *MCPred = Predicate->getValueAsDef("Predicate");
-    emitFirstPredicate(MCPred, PE, OS);
-    emitSecondPredicate(MCPred, PE, OS);
+    emitFirstPredicate(Predicate, PE, OS);
+    emitSecondPredicate(Predicate, PE, OS);
   } else if (Predicate->isSubClassOf("TieReg")) {
     int FirstOpIdx = Predicate->getValueAsInt("FirstOpIdx");
     int SecondOpIdx = Predicate->getValueAsInt("SecondOpIdx");
@@ -208,8 +228,19 @@ void MacroFusionPredicatorEmitter::emitBothPredicate(Record *Predicate,
                  << ").isReg() &&\n";
     OS.indent(2) << "      FirstMI->getOperand(" << FirstOpIdx
                  << ").getReg() == SecondMI.getOperand(" << SecondOpIdx
-                 << ").getReg()))\n";
-    OS.indent(2) << "  return false;\n";
+                 << ").getReg())) {\n";
+
+    OS.indent(4) << "if (!SecondMI.getDesc().isCommutable())\n";
+    OS.indent(4) << "  return false;\n";
+
+    OS.indent(4) << "unsigned SrcOpIdx1 = " << SecondOpIdx
+                 << ", SrcOpIdx2 = TargetInstrInfo::CommuteAnyOperandIndex;\n";
+    OS.indent(4)
+        << "if (TII.findCommutedOpIndices(SecondMI, SrcOpIdx1, SrcOpIdx2))\n";
+    OS.indent(4) << "  if (FirstMI->getOperand(" << FirstOpIdx
+                 << ").getReg() != SecondMI.getOperand(SrcOpIdx2).getReg())\n";
+    OS.indent(4) << "    return false;\n";
+    OS.indent(2) << "}\n";
   } else
     PrintFatalError(Predicate->getLoc(),
                     "Unsupported predicate for both instruction: " +
 | 
74c60af    to
    ac50115      
    Compare
  
    ac50115    to
    b426ef9      
    Compare
  
    | I added a  | 
| 
 Wouldn't that be implied by the opcode in the first place? The instruction will already have isCommutable set | 
| 
 We can't simply use the  class SingleFusion<string name, string fieldName, string desc,
                   Instruction firstInst, Instruction secondInst,
                   MCInstPredicate firstInstPred = TruePred,
                   MCInstPredicate secondInstPred = TruePred,
                   list<FusionPredicate> prolog = [],
                   list<FusionPredicate> epilog = []>
    : SimpleFusion<name, fieldName, desc,
                   CheckAll<!listconcat(
                              [CheckOpcode<[firstInst]>],
                              [firstInstPred])>,
                   CheckAll<!listconcat(
                              [CheckOpcode<[secondInst]>],
                              [secondInstPred])>,
                   prolog, epilog> {
  let IsCommutable = secondInst.isCommutable;
}
def ShiftNAddFusion
  : SingleFusion<"shift-n-add-fusion", "HasShiftNAddFusion",
                 "Enable SLLI+ADD to be fused to shift left by 1/2/3 and add",
                 SLLI, ADD,
                 CheckAny<[CheckImmOperand<2, 1>,
                           CheckImmOperand<2, 2>,
                           CheckImmOperand<2, 3>]>>; | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yet another commutable flag is a bit annoying/undiscoverable but OK I guess
If the second instruction is commutable, we should be able to check its commutable operands. A field `IsCommutable` is added to indicate whether we should generate code for checking commutable operands. Fixes llvm#82738
b426ef9    to
    d7954a8      
    Compare
  
    
If the second instruction is commutable, we should be able to check
its commutable operands.
A simple RISCV fusion is contained in this PR to show the functionality
is correct, I may remove it when landing.
Fixes #82738